#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@author: liang kang
@contact: gangkanli1219@gmail.com
@time: 2018/3/19 13:49
@desc: 
"""
import numpy as np


def check_image_format(image):
    """
    Normalize an image array to [height, width, channels] layout.

    Recognised input layouts:

    * [height, width]          -> returned as [height, width, 1]
    * [3, height, width] (CHW) -> returned as [height, width, 3]
    * any other 3-D array      -> assumed to already be [height, width,
      channels] and returned unchanged

    Note: a 3-D array whose first dimension is 3 is always treated as
    channels-first, so an HWC image that happens to have height 3 is
    ambiguous and will be transposed.

    Parameters
    ----------
    image : array_like
        A 2-D or 3-D image array.

    Returns
    -------
    image : numpy.ndarray
        The image in [height, width, channels] layout.
    restore : callable
        Function that maps an array in [height, width, channels] layout back
        to the original layout of ``image``.

    Raises
    ------
    ValueError
        If ``image`` is neither 2-D nor 3-D.
    """
    # No-op for ndarrays (and subclasses); lets plain nested lists work too.
    image = np.asanyarray(image)
    shape = image.shape
    if 2 == len(shape):
        image = np.expand_dims(image, 2)
        # Squeeze only the channel axis we added; a bare np.squeeze would
        # also drop a height or width of size 1.
        return image, lambda img: np.squeeze(img, axis=2)
    if 3 == len(shape):
        if 3 == shape[0]:
            # Channels-first (CHW) -> HWC, and the inverse permutation.
            image = np.transpose(image, [1, 2, 0])
            return image, lambda img: np.transpose(img, [2, 0, 1])
        # Already HWC: identity restore.
        return image, lambda img: img
    # Other ranks are unsupported; fail loudly instead of returning None,
    # which would break tuple unpacking at the call site.
    raise ValueError(
        'expected a 2-D or 3-D image, got shape {}'.format(shape))
